"""
% %
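% FINITE-DIFFERENCE PHASE-FIELD CODE FOR SOLVING %
% CAHN-HILLIARD EQUATION %
% (NOTE: the time loop below actually integrates a non-conserved, %
% Allen-Cahn-type equation for each grain order parameter eta.) %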
% %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
"""
from datetime import datetime
import matplotlib.pyplot as plt
import imageio
from scipy import ndimage, misc
import numpy as np
import csv
import argparse, math

def save_out_file(etas, output_file, Nx, Ny):
    """Write a set of order-parameter fields to a CSV-style text file.

    The first line is the header ``Nx,Ny,<number of fields>``. It is followed,
    field after field, by the first ``Nx`` rows of each array in ``etas``,
    one CSV row per array row.

    Args:
        etas: sequence of 2-D arrays (one per grain).
        output_file: path of the file to create/overwrite.
        Nx: number of rows written per field (also written to the header).
        Ny: written to the header only; rows are written as-is, not
            truncated or checked against ``Ny``.
    """
    # newline='' is what the csv module requires for files it writes to;
    # without it the "\r\n" row terminator becomes "\r\r\n" on Windows.
    with open(output_file, 'w', newline='') as f:
        f.write(f'{Nx},{Ny},{len(etas)}\n')
        csv_writer = csv.writer(f)
        for eta in etas:
            csv_writer.writerows(eta[i] for i in range(Nx))


def save_out_n_file(all_etas, output_file, Nx, Ny, n_step):
    """Write a time series of order-parameter fields to a CSV-style file.

    The header is ``Nx,Ny,<fields per snapshot>,n_step``. Then, for each
    snapshot in order, the first ``Nx`` rows of every field of that snapshot
    are written, one CSV row per array row.

    Args:
        all_etas: sequence of ``n_step`` snapshots; each snapshot is a
            sequence of 2-D arrays (one per grain).
        output_file: path of the file to create/overwrite.
        Nx: rows written per field (also written to the header).
        Ny: written to the header only.
        n_step: expected number of snapshots in ``all_etas``.

    Raises:
        ValueError: if ``len(all_etas) != n_step``. The check runs before the
            file is opened, so an existing file is not truncated on failure.
    """
    if len(all_etas) != n_step:
        raise ValueError(
            'expected %d snapshots in all_etas, got %d' % (n_step, len(all_etas)))
    # Fields per snapshot, taken from the first snapshot; an empty series has none.
    n_fields = len(all_etas[0]) if all_etas else 0
    # newline='' is required by the csv module for the files it writes to.
    with open(output_file, 'w', newline='') as f:
        f.write(f'{Nx},{Ny},{n_fields},{n_step}\n')
        csv_writer = csv.writer(f)
        for snapshot in all_etas:
            for eta in snapshot:
                csv_writer.writerows(eta[i] for i in range(Nx))


def dis(i1, j1, i2, j2):
    """Return the Euclidean distance between the points (i1, j1) and (i2, j2)."""
    delta_i = i1 - i2
    delta_j = j1 - j2
    squared = delta_i ** 2 + delta_j ** 2
    return math.sqrt(squared)

def prepare_grain_v1(Nx, Ny, n_grain):
    """Build a two-grain initial condition: a circular grain inside a matrix.

    Cells whose distance from the grid centre ``(Nx/2, Ny/2)`` is strictly
    less than ``Nx/8`` get ``eta1 = 1, eta2 = 0``. Every other cell gets
    ``eta1 = 0, eta2 = 1``.

    Args:
        Nx: number of rows of the grid.
        Ny: number of columns of the grid.
        n_grain: number of order parameters; only 2 is supported.

    Returns:
        list ``[eta1, eta2]`` of float64 arrays of shape ``(Nx, Ny)``.

    Raises:
        ValueError: if ``n_grain != 2``.
    """
    if n_grain != 2:
        raise ValueError(
            'prepare_grain_v1 supports n_grain == 2 only, got %r' % (n_grain,))
    # Distance of every cell (i, j) to the centre, computed in one vectorized
    # pass with the same arithmetic as dis(i, j, Nx/2, Ny/2).
    di = np.arange(Nx)[:, None] - Nx / 2
    dj = np.arange(Ny)[None, :] - Ny / 2
    inside = np.sqrt(di ** 2 + dj ** 2) < Nx / 8
    eta1 = np.where(inside, 1.0, 0.0)
    eta2 = np.where(inside, 0.0, 1.0)
    return [eta1, eta2]

def prepare_grain_v2(Nx, Ny, n_grain):
    """Build a two-grain initial condition: a circular grain with a diffuse rim.

    With ``d`` the distance of a cell from the grid centre ``(Nx/2, Ny/2)``:

    * ``d < Nx/4``: ``eta1 = 1, eta2 = 0`` (grain core);
    * ``Nx/4 <= d < 3*Nx/8``: ``eta1 = eta2 = 0.5`` (ring);
    * otherwise: ``eta1 = 0, eta2 = 1``.

    Args:
        Nx: number of rows of the grid.
        Ny: number of columns of the grid.
        n_grain: number of order parameters; only 2 is supported.

    Returns:
        list ``[eta1, eta2]`` of float64 arrays of shape ``(Nx, Ny)``.

    Raises:
        ValueError: if ``n_grain != 2``.
    """
    if n_grain != 2:
        raise ValueError(
            'prepare_grain_v2 supports n_grain == 2 only, got %r' % (n_grain,))
    # Distance field computed once for the whole grid (same arithmetic as
    # dis(i, j, Nx/2, Ny/2)), instead of up to two dis() calls per cell.
    di = np.arange(Nx)[:, None] - Nx / 2
    dj = np.arange(Ny)[None, :] - Ny / 2
    dist = np.sqrt(di ** 2 + dj ** 2)
    core = dist < Nx / 4
    ring = dist < 3 * Nx / 8  # only consulted where `core` is False
    eta1 = np.where(core, 1.0, np.where(ring, 0.5, 0.0))
    eta2 = np.where(core, 0.0, np.where(ring, 0.5, 1.0))
    return [eta1, eta2]

def prepare_grain_v3(Nx, Ny, n_grain):
    """Build a two-grain initial condition separated by a sinusoidal boundary.

    Cell ``(i, j)`` belongs to grain 1 (``eta1 = 1, eta2 = 0``) when
    ``i > Nx/2 + Nx/8 * sin(2*pi*20/Ny * j)``, otherwise to grain 2
    (``eta1 = 0, eta2 = 1``). The boundary oscillates around row ``Nx/2``
    with amplitude ``Nx/8`` and makes 20 sine periods across the ``Ny``
    columns.

    Args:
        Nx: number of rows of the grid.
        Ny: number of columns of the grid.
        n_grain: number of order parameters; only 2 is supported.

    Returns:
        list ``[eta1, eta2]`` of float64 arrays of shape ``(Nx, Ny)``.

    Raises:
        ValueError: if ``n_grain != 2``.
    """
    if n_grain != 2:
        raise ValueError(
            'prepare_grain_v3 supports n_grain == 2 only, got %r' % (n_grain,))
    # The boundary depends only on the column j, so evaluate it once per
    # column (with math.sin, exactly as per cell before) rather than Nx times.
    boundary = np.array(
        [Nx / 2 + Nx / 8 * math.sin(2 * math.pi * 20 / Ny * j) for j in range(Ny)],
        dtype=float,
    )
    grain1 = np.arange(Nx)[:, None] > boundary[None, :]
    eta1 = np.where(grain1, 1.0, 0.0)
    eta2 = np.where(grain1, 0.0, 1.0)
    return [eta1, eta2]

def read_init_file(filename):
    """Read order-parameter fields in the layout written by ``save_out_file``.

    The file must start with a header line ``Nx,Ny,n_grain``, followed by
    ``n_grain`` blocks of ``Nx`` CSV rows each. Every value is parsed with
    ``float``.

    Args:
        filename: path of the file to read.

    Returns:
        tuple ``(Nx, Ny, n_grain, etas)``, where ``etas`` is a list of
        ``n_grain`` numpy arrays, each built from one block of ``Nx`` rows.
        ``Ny`` is returned as read from the header and is not checked
        against the row lengths.

    Raises:
        ValueError: if the file ends before the header or before all
            ``n_grain * Nx`` rows have been read (previously a bare
            StopIteration escaped), or if a value is not numeric.
    """
    # newline='' is the mode the csv module expects for files it reads.
    with open(filename, 'r', newline='') as f:
        csv_reader = csv.reader(f)
        try:
            first_line = next(csv_reader)
        except StopIteration:
            raise ValueError('%s: missing header line' % (filename,)) from None
        Nx = int(first_line[0])
        Ny = int(first_line[1])
        n_grain = int(first_line[2])
        etas = []
        for pg in range(n_grain):
            mtx = []
            for i in range(Nx):
                try:
                    line_str = next(csv_reader)
                except StopIteration:
                    raise ValueError(
                        '%s: file ended at row %d of field %d (expected %d rows '
                        'for each of %d fields)' % (filename, i, pg, Nx, n_grain)
                    ) from None
                mtx.append([float(itm) for itm in line_str])
            etas.append(np.array(mtx))
    return (Nx, Ny, n_grain, etas)

##### Main Program

# Command-line interface: only the name of the output file is configurable.
parser = argparse.ArgumentParser(description='Generate Initial Condition for Grain Growth.')
parser.add_argument('--output', default="grain_growth_init.out",
                    help='output file.')

args = parser.parse_args()

# "grain_growth_init.out" unless --output was given.
# NOTE(review): output_file is only consumed by save calls that are
# currently commented out in this script.
output_file = args.output

#-- Simulation cell parameters:

Nx = 128
Ny = 128
n_grain = 2

# Initial condition: two grains separated by a sinusoidal boundary.
etas = prepare_grain_v3(Nx, Ny, n_grain)

# save_out_file(etas, output_file, Nx, Ny)

# Grid spacing. The Laplacian in the time loop is divided by dx*dy, which is
# only the correct scaling when dx == dy.
dx = 0.5
dy = 0.5

# Coefficients of the bulk driving force used in the time loop:
#     d_energy = -A*eta + B*eta**3 + 2*eta*sum_{j != i} eta_j**2
A = 1.0
B = 1.0

L = 5.0      # kinetic coefficient multiplying the driving force ("mobil" in the original code)
kappa = 0.1  # gradient-energy coefficient multiplying the Laplacian ("grcoef" in the original code)

nprint = 100   # print diagnostics every nprint steps (and at step 1)
dtime = 0.05   # time step
ttime = 0.0    # elapsed simulation time

NxNy = Nx*Ny   # NOTE(review): not used anywhere in the visible script

epsilon = 0.001  # values below this count as "zero" in the diagnostics

nstep = 1700   # number of snapshots kept: the initial state plus nstep - 1 updates

# One snapshot per step, each holding a copy of every order parameter;
# index 0 is the initial condition.
all_etas = [[np.copy(eta) for eta in etas]]

# Explicit (forward-Euler) time integration. Each step updates every order
# parameter with
#     eta_i <- eta_i - dtime * L * (d_energy_i - kappa * laplacian(eta_i))
# and clips the result to [0, 1]. The loop runs nstep - 1 times, so together
# with the initial snapshot all_etas ends up holding nstep entries.
for istep in range(1, nstep):
    ttime += dtime

    # sum_j eta_j**2 from the start-of-step values. Grain i subtracts its own
    # (not yet updated) eta_i**2 below, so the coupling term always sees the
    # other grains' start-of-step values regardless of update order.
    sum_eta_2 = sum(eta ** 2 for eta in etas)

    for i in range(n_grain):
        # Bulk driving force: polynomial terms -A*eta_i + B*eta_i**3 plus the
        # 2*eta_i*sum_{j != i} eta_j**2 coupling to the other grains.
        d_energy = -A*etas[i] + B*(etas[i]**3) + 2*etas[i]*(sum_eta_2 - etas[i]**2)
        # ndimage.laplace sums [1, -2, 1] second differences along each axis
        # with unit spacing (edge values replicated by mode='nearest');
        # dividing by dx*dy rescales to the grid spacing, assuming dx == dy.
        lap_eta = ndimage.laplace(etas[i], mode='nearest') / (dx*dy)

        etas[i] -= dtime*L*(d_energy - kappa*lap_eta)

        # Keep the order parameter inside [0, 1].
        etas[i] = np.clip(etas[i], 0.0, 1.0)

    # Snapshot of every order parameter after this step.
    # NOTE(review): every step is kept in memory, about
    # nstep * n_grain * Nx * Ny * 8 bytes (~446 MB for float64, 128x128,
    # 2 grains, 1700 steps).
    all_etas.append([np.copy(eta) for eta in etas])

    #---- print results
    if istep % nprint == 0 or istep == 1:
        print('Step: %5d' % istep)
        # Number of cells whose value is below epsilon, per grain.
        for k, eta in enumerate(etas):
            print("Zero elements in etas[%d]: %10d" % (k, np.count_nonzero(eta < epsilon)))
        # TODO(review): VTK output and the total-energy report were only
        # placeholders here; calculate_energ is not defined in this file.

# output_file = '../data/grain_growth_all_data'
# save_out_n_file(all_etas, output_file, Nx, Ny, nstep)

# Total of each of the first two order parameters, first for the initial
# snapshot and then for snapshot 1500.
for snapshot_index in (0, 1500):
    for grain_index in (0, 1):
        print(np.sum(all_etas[snapshot_index][grain_index]))
